#--------------
# Figure 3: plot the differences in week-to-week treatment transition rates
# (estimates with confidence intervals, as a lollipop / point-range chart).
# Also produces Table S8.
#==============
load_library = c('bit64','data.table','fst','future.apply','stringr','logger','vroom')
invisible(lapply(load_library, function(x) library(x, character.only=TRUE, quietly= TRUE)))

library(parallel)
library(tidyverse)
library(ggalluvial)
library(ggsci)

# Format number(s) for display: multiply by `scale` (100 turns a proportion
# into a percentage), round to `digit` decimal places and return character.
# Despite the name this uses round(), not a ceiling; trailing zeros are not
# padded (e.g. 0.125 -> "12.5", not "12.50").
round_up = function(x, scale=100, digit=2) {
	scaled = x * scale
	rounded = round(scaled, digits = digit)
	as.character(rounded)
}

# Root folder of the replication data (contains data/processed_data/ and
# receives the figure and table outputs). Set the DATAVERSE_BUCKET
# environment variable to run the script elsewhere; otherwise the original
# author's local path is used.
bucket = Sys.getenv('DATAVERSE_BUCKET')
if (!nzchar(bucket)) {
	bucket = '/Users/bk/Dropbox/project/OP/- posted/2021-06-08_covid_substitution_JAMA/dataverse'
}

#--------------
# load the estimated changes (differences) in transition rates between periods
#==============
dt = readLines(file.path(bucket, 'data','processed_data','diff_transition_period.txt'))

# The file interleaves two kinds of lines:
#   * lines without a tab: the name of the next results matrix, of the form
#     '<dv1>_<dv2>_<state_t0>_<state_t1>_<period>' (split further below);
#   * tab-separated lines: the rows of that matrix.
matrix_name = grep('\\t', dt, value=TRUE, invert=TRUE)
# ignore blank separator lines, if any, so they are not taken as matrix names
matrix_name = matrix_name[nzchar(trimws(matrix_name))]
dt = grep('\\t', dt, value=TRUE, invert=FALSE)
dt = fread(text=dt)
# drop rows whose first field is empty (presumably repeated column headers),
# leaving one row per statistic
dt = dt[V1 != '', ]

# Number of statistic rows each matrix contributes. The stat names used below
# (b, se, t, pvalue, ll, ul) suggest Stata's 9-row r(table) layout -- confirm.
rows_per_matrix = 9
if (nrow(dt) != rows_per_matrix * length(matrix_name)) {
	stop(sprintf('expected %d statistic rows for each of %d matrices, but read %d rows',
		rows_per_matrix, length(matrix_name), nrow(dt)), call. = FALSE)
}
dt[, outcome := rep(matrix_name, each = rows_per_matrix)]
dt[, V5 := NULL]
setnames(dt, c('stat', 'ref', 'diff', 'constant', 'outcome'))

# keep only the 'diff' column (the between-period difference) per statistic
dt = melt(dt[, c('stat', 'diff', 'outcome')], id.vars=c('stat', 'outcome'))
dt[, value := as.numeric(value)]

# decode the matrix name into its components; period is its first digit
dt[, c('dv1','dv2','state_t0','state_t1','period') := tstrsplit(outcome, '_')]
dt[, period := as.numeric(stringr::str_extract(period, '\\d{1}'))]

# one row per (period, state_t0, state_t1), one column per statistic
all_out = dcast(dt[stat %in% c('b', 'se', 't', 'pvalue', 'll', 'ul'), ], 
	period + state_t0 + state_t1 ~ stat, value.var = c('value'))

# Label the estimation periods and the treatment states at t0 and t1.
all_out[, data := factor(period, levels=c(1,2,3), labels=c('prepandemic','pandemic','latepandemic'))]

all_out[, condition_t0 := factor(state_t0, levels = c(0,1,2,3),
	labels=paste0('t0: ',c('no treatment','opioid only','therapy only','both')))]
all_out[, condition_t1 := factor(state_t1, levels = c(0,1,2,3),
	labels=paste0('t1: ',c('no treatment','opioid only','therapy only','both')))]

# One row per (period, t0 state, t1 state) holding the estimate b and its
# confidence-interval bounds ll / ul.
tab_diff = dcast(all_out, data + condition_t0 + condition_t1 ~ ., value.var = c('b','ll','ul'))

# Figure 3 and Table S8 compare only the pre-pandemic and pandemic periods.
# Drop any other period (e.g. 'latepandemic') explicitly: re-factoring with
# just two levels would instead turn those rows into NA-period rows that are
# kept and still reach the plot without a legend group.
tab_diff = tab_diff[data %in% c('prepandemic','pandemic')]
tab_diff[, data := factor(data, levels = c('prepandemic','pandemic'))]

# values for display, multiplied by 100
tab_diff[, diff := 100 * b]
tab_diff[, lci := 100 * ll]
tab_diff[, uci := 100 * ul]

# direction of change (0 counts as 'negative'); baseline = 0 is where the
# dotted stems of the plot start
tab_diff[, positive := ifelse(diff > 0, 'positive','negative')]
tab_diff[,baseline := 0]

# legend group: pre-pandemic estimates form a single group, pandemic
# estimates are split by direction (rows with NA diff stay unassigned)
tab_diff[data=='prepandemic' & positive %in% c('positive','negative'), data_positive := 'pre-pandemic']
tab_diff[data=='pandemic' & positive=='positive', data_positive := 'increase during pandemic']
tab_diff[data=='pandemic' & positive=='negative', data_positive := 'decrease during pandemic']

# Display labels. t1 levels are reversed so that after coord_flip() the
# options read top-to-bottom starting with 'no treatment'.
tab_diff[, condition_t1 := factor(condition_t1, 
	levels=rev(paste0('t1: ',c('no treatment','opioid only','therapy only','both'))),
	labels=rev(c('no treatment','opioid only','therapy only','opioid + therapy')))]
tab_diff[, condition_t0 := factor(condition_t0, 
	levels=paste0('t0: ',c('no treatment','opioid only','therapy only','both')),
	labels=c('no treatment','opioid only','therapy only','opioid + therapy'))]

# Figure 3: one facet per treatment at t0; within a facet, one row per
# treatment at t1 showing the difference in transition rate (dotted stem from
# 0 plus point with CI), pre-pandemic and pandemic estimates dodged apart.
p = local({
	group_colors = c(
		'pre-pandemic' = 'black',
		'increase during pandemic' = 'red',
		'decrease during pandemic' = 'blue'
	)
	group_shapes = c(
		'pre-pandemic' = 4,
		'increase during pandemic' = 19,
		'decrease during pandemic' = 19
	)
	dodge = position_dodge(width = 0.3)

	ggplot(tab_diff, aes(x = condition_t1, y = diff, ymin = lci, ymax = uci,
			color = data_positive, fill = data_positive, shape = data_positive)) +
		# the lollipop stick: dotted line from the baseline (0) to the estimate
		geom_linerange(aes(x = condition_t1, ymin = baseline, ymax = diff),
			position = dodge, color = 'black', size = 0.3, linetype = 'dotted') +
		geom_pointrange(position = dodge, size = 0.3, alpha = 1) +
		geom_hline(yintercept = 0, color = 'black', size = 0.3) +
		scale_shape_manual(name = NULL, values = group_shapes) +
		scale_color_manual(name = NULL, values = group_colors) +
		scale_fill_manual(name = NULL, values = group_colors) +
		coord_flip() +
		facet_grid(. ~ condition_t0, switch = 'y') +
		theme_bw() +
		theme(
			panel.grid.minor = element_blank(),
			plot.title = element_text(size = 16, face = "bold"),
			plot.subtitle = element_text(size = 9, hjust = 0, color = "black"),
			plot.caption = element_text(size = 8, margin = margin(t = 12), color = "#7a7d7e"),
			legend.position = 'bottom',
			strip.background = element_rect(fill = "white")
		) +
		labs(
			x = 'Treatment options in the following week',
			y = 'Mean differences in the transition rates between 2019 and 2020',
			subtitle = 'Among pain patients visiting in the current week who receive ...'
		)
})

ggsave(filename = file.path(bucket, 'data', 'figure_3_lolly_pop_transition_CI.eps'),
	plot = p, width = 8, height = 4)


# Table S8: the same differences as matrices, one CSV per period. Each cell is
# "b [ll, ul]" with all three values multiplied by 100 and rounded to 2
# decimals; rows are the treatment at t0, columns the treatment at t1.
tab_diff[, print_cell := paste0(round_up(b, digit=2), ' [', round_up(ll, digit=2), ', ', round_up(ul, digit=2), ']')]

# condition_t1 carries reversed levels (chosen for the coord_flip() plot), and
# dcast() orders the new columns by factor level. Put the columns back in the
# same order as the rows, so the "same treatment in both weeks" cells lie on
# the diagonal.
treatment_levels = c('no treatment','opioid only','therapy only','opioid + therapy')
tab_s8 = copy(tab_diff)
tab_s8[, condition_t1 := factor(as.character(condition_t1), levels = treatment_levels)]

diff_prepandemic = dcast(tab_s8[data == 'prepandemic',], 
	condition_t0 ~ condition_t1, value.var = c('print_cell'))
diff_pandemic = dcast(tab_s8[data == 'pandemic',], 
	condition_t0 ~ condition_t1, value.var = c('print_cell'))

# row.names = FALSE: without it write.csv prepends an unnamed 1..n index column
write.csv(diff_prepandemic, 
	file.path(bucket, 'data','table_s8_diff_transition_mat_prepandemic.csv'),
	row.names = FALSE)
write.csv(diff_pandemic, 
	file.path(bucket, 'data','table_s8_diff_transition_mat_pandemic.csv'),
	row.names = FALSE)


